import numpy
from copy      import copy
from operator  import mul
from .utils     import CfList
from .functions import CHUNKSIZE, FM_THRESHOLD
from .partition import Partition

# ====================================================================
#
# PartitionArray object
#
# ====================================================================

class PartitionArray(CfList):
    '''

An N-dimensional partition array.

'''

    def __init__(self, array, **kwargs):
        '''

**Initialization**

:Parameters:

    array : list
        A (possibly nested) list of Partition objects.

    direction : bool or dict
        The direction of each dimension of the data array. It is a
        boolean if the data array is a scalar array.

#    dimensions : list
#        The identities of the dimensions of the data array. If the
#        data array is a scalar array then it is an empty list.

    dimensions : list
        The identities of the partition dimensions of the partition
        array. If partition array is a scalar array then it is an
        empty list.

**Examples**

>>> pa = PartitionArray(
        [Partition(location   = [(0,n) for n in shape],
                   shape      = shape[:],
                   dimensions = dimensions[:],
                   direction  = copy(direction),
                   Units      = units.copy(),
                   part       = [],
                   data       = data)
         ],
#        dimensions  = dimensions,
        direction   = direction,
        dimensions = [])

'''
        super(PartitionArray, self).__init__()

        self._list = array

        for attr, value in kwargs.iteritems():
            setattr(self, attr, value)
            
#        self.hardmask = hardmask
    #--- End: def
    
    def __repr__(self):
        '''
x.__repr__() <==> repr(x)

'''
        return '<CF %s: %d partition dimensions>' % (self.__class__.__name__, 
                                                     len(self.dimensions))
    #--- End: def

    def __nonzero__(self):
        '''
x.__nonzero__() <==> bool(x)

'''
        return True
    #--- End: def

    def __getitem__(self, indices):
        '''
x.__getitem__(indices) <==> x[indices]

'''
        if isinstance(indices, tuple):
            out = self._list
            for i in indices:
                out = out[i]
        else:
            out = self._list[indices]

        if isinstance(out, Partition):
            return out

        # out is a list so return a PartitionArray
        return type(self)(out, dimensions=self.dimensions,
                          direction=self.direction)
#                          dimensions=self.dimensions,
    #--- End: def

    def __setitem__(self, indices, value):
        '''
x.__setitem__(indices, y) <==> x[indices]=y

'''
        if isinstance(indices, tuple):
            s = self._list
            for i in indices[:-1]:
                s = s[i]

            s[indices[-1]] = value
        else:
            self._list[indices] = value
    #--- End: def

    def __str__(self):
        '''
x.__str__() <==> str(x)

'''
        out = []
        for partition in self.flat():
            out.append(str(partition))

        return '\n'.join(out)
    #--- End: def

    # ----------------------------------------------------------------
    # Attribute: _is_1d (can't set or delete)
    # ----------------------------------------------------------------
    @property
    def _is_1d(self):
        '''

True if the partition array is one dimensionsal in partition space.

**Examples**

>>> pa._list
[<cf.partition.Partition object at 0x4dux938>]
>>> pa._is_1d
True

>>> pa._list
[[<cf.partition.Partition object at 0x4dbc938>,
  <cf.partition.Partition object at 0x4dbc9b0>],
 [<cf.partition.Partition object at 0x4dbcde8>,
  <cf.partition.Partition object at 0x4db9758>]]
>>> pa._is_1d
False

>>> pa._list
[[<cf.partition.Partition object at 0x4dyc038>]]
>>> pa._is_1d
False

'''
        return isinstance(self[0], Partition)
    #--- End: if

#    # ----------------------------------------------------------------
#    # Attribute: dtype (can't set or delete)
#    # ----------------------------------------------------------------
#    @property
#    def dtype(self):
#        '''
#
#The numpy data type of the master array.
#
#This is the data type with the smallest size and smallest scalar kind
#to which all data array partitions may be safely cast without loss of
#information. For example, if the partitions have data types 'int64'
#and 'float32' then the data array's data type will be 'float64' or if
#the partitions have data types 'int64' and 'int32' then the data
#array's data type will be 'int64'.
#
#**Examples**
#
#>>> type(pa.dtype)
#<type 'numpy.dtype'>
#>>> pa.dtype
#dtype('float64')
#
#'''
#        
#        return numpy_result_type(*(partition.data for partition in self.flat()))
#    #--- End: def

    # ----------------------------------------------------------------
    # Attribute: isscalar (can't set or delete)
    # ----------------------------------------------------------------
    @property
    def isscalar(self):
        '''

True if the master array is a 0-d scalar array.

**Examples**

>>> pa.isscalar
True

'''
        return self.size == 1 and self[0].data.ndim == 0
    #--- End: def

    # ----------------------------------------------------------------
    # Attribute: ndim (can't set or delete)
    # ----------------------------------------------------------------
    @property
    def ndim(self):
        '''

The number of partition dimensions in the partition array.

Not to be confused with the number of data array dimensions.

**Examples**

>>> pa.shape
[8, 4]
>>> pa.ndim
2

'''
        return len(self.dimensions)
    #--- End: def

    # ----------------------------------------------------------------
    # Attribute: shape (can't set or delete)
    # ----------------------------------------------------------------
    @property
    def shape(self):
        '''

List of the partition arrays dimension sizes.

Not to be confused with the sizes of the data array's dimensions.

**Examples**

>>> pa.ndim
2
>>> pa.size
32
>>> pa.shape
[8, 4]

'''
        out = [len(self._list)]
        
        nest = self._list[0]
        while isinstance(nest, list):
            out.append(len(nest))
            nest = nest[0]
        #--- End: while

        return out
    #--- End: def

    # ----------------------------------------------------------------
    # Attribute: size (can't set or delete)
    # ----------------------------------------------------------------
    @property
    def size(self):
        '''

The number of partitions in the partition array.

Not to be confused with the number of elements in the data array.

**Examples**

>>> pa.shape
[8, 4]
>>> pa.size
32

'''
        return long(reduce(mul, self.shape, 1))
    #--- End: def

    def add_partitions(self,
                       adimensions,
                       extra_boundaries, 
                       pdim,
                       existing_boundaries=None):
        '''

'''     

        def _add_partitions(partitions, existing_boundaries, extra_boundaries, 
                            dimensions, pdim, index, direction, dim2index,
                            adimensions):
            '''
            '''
            def _update_p(partitions, location, index, 
                          part, dim2index, direction):
                '''
'''                
                if isinstance(partitions, Partition):
                    partitions = [partitions]                        
                    
                elif not partitions._is_1d: # isinstance(partitions[0], Partition):
                    for nested_partitions in partitions:
                        # Recursive call
                        _update_p(nested_partitions,
                                  location, index,
                                  part, dim2index, direction)
                    #--- End: for
                    return
                #--- End: if

                for partition in partitions:
                    partition.location[index] = location
                    partition.shape[index]    = shape                       
                    partition.part = partition.new_part(part, 
                                                        dim2index, 
                                                        direction)
                #--- End: for
            #--- End: def

#            full = [slice(None)] * len(self.dimensions)
            full = [slice(None)] * len(adimensions)

            n_pdim = len(dimensions)

            if pdim == dimensions[0]:

                new_list = []
                extra_boundaries = extra_boundaries[:]
                x = extra_boundaries.pop(0)

                for p, r0, r1 in zip(partitions,
                                     existing_boundaries[:-1], 
                                     existing_boundaries[1:]):

                    if not r0 < x < r1:
                        if n_pdim == 1:
                            # p is a partition
                            new_list.append(p)
                        else:
                            # p is a partition array
                            new_list.append(p._list)

                        continue
                    #--- End: if

                    # Still here?

                    # Find the new extent of the original partition(s)
                    location    = (r0, x)
                    shape       = x - r0
                    part        = full[:]
                    part[index] = slice(0, shape)

                    # Create new partition(s) in place of the original
                    # ones(s) and set the location, shape and part
                    # attributes
                    new_p = p.copy()
                    _update_p(new_p,
                              location, index,
                              part, dim2index, direction)

                    # Append the new partition(s) to the new list
                    if n_pdim == 1:
                        new_list.append(new_p)
                    else:
                        new_list.append(new_p._list)

                    while x < r1:

                        # Find the extent of the new partition(s)
                        if not extra_boundaries:
                            # No more new boundaries, so the new
                            # partition(s) run to the end of the
                            # original partition(s) in which they lie.
                            location = (x, r1)
                        else:
                            # There are more new boundaries, so this
                            # new partition runs either to the next
                            # new boundary or to the end of the
                            # original partition, which comes first.
                            location = (x, min(extra_boundaries[0], r1))

                        shape       = location[1] - location[0]
                        offset      = x - r0
                        part        = full[:]
                        part[index] = slice(offset, offset+shape)

                        # Create the new partition(s) and set the
                        # location, shape and part attributes
                        new_p = p.copy()
                        _update_p(new_p,
                                  location, index,
                                  part, dim2index, direction)
                        
                        # Append the new partition(s) to the new list
                        if n_pdim == 1:
                            new_list.append(new_p)
                        else:
                            new_list.append(new_p._list)

                        # Move on to the next new boundary, if there
                        # is one
                        if not extra_boundaries:
                            break

                        x = extra_boundaries.pop(0)                        
                    #--- End: while                        
                #--- End: for 
                        
                # Update partitions array in place
                partitions[:] = new_list
                
            else:
                dimensions = dimensions[1:]
                if not dimensions:
                    return
                for nested_partitions in partitions:
                    # Recursive call
                    _add_partitions(nested_partitions,
                                    existing_boundaries, extra_boundaries,
                                    dimensions, pdim, index,
                                    direction, dim2index)
            #--- End: for
        #--- End: def
                    
        # If no extra boundaries have been provided, just return
        # without doing anything
        if not extra_boundaries:
            return

        # Find the existing boundaries if they haven't been provided
        if existing_boundaries is None:
            existing_boundaries = self.partition_boundaries(adimensions)[pdim]
            
#        dim2index = dict([(dim, i) for i, dim in enumerate(self.dimensions)])
        dim2index = dict([(dim, i) for i, dim in enumerate(adimensions)])

        _add_partitions(self, existing_boundaries, extra_boundaries, 
                        self.dimensions, pdim, adimensions.index(pdim), 
                        self.direction,
                        dim2index, adimensions)        
    #--- End: def
  
    def change_dimension_names(self, dim_name_map):
        '''

dim_name_map should be a dictionary which maps each dimension names in
self.dimensions to its new dimension name. E.g. {'dim0':'dim1',
'dim1':'dim0'}

'''
        # Check for a null dimension name mapping
        # (e.g. 'dim0'->'dim0', etc.) and return if there is nothing
        # to change. I think that this is worth it, on the grounds
        # that this is often be the case.
        map_is_null = True
        for dim0, dim1 in dim_name_map.iteritems():
            if dim0 != dim1:
                map_is_null = False
                break
        #--- End: for       
        if map_is_null:
            return
            
        # Still here? Then some dimension names need changing.

        # Dimension order
#        self.dimensions = [dim_name_map[dim] for dim in self.dimensions]
        
        # Partition dimensions
        self.dimensions = [dim_name_map[dim] for dim in self.dimensions]

        # Dimension directions
        if not self.isscalar:
            self.direction = dict([(dim_name_map[dim], value)
                                   for dim, value in self.direction.iteritems()]
                                  )

        # Partitions (Note that a partition may have dimensions which
        # are not in self.dimensions and that these must also be in
        # dim_name_map).
        for partition in self.flat():

            # Partition data dimension dimensions
            partition.dimensions = [dim_name_map[dim] for dim in partition.dimensions]
            
            # Partition data dimension directions
            direction = partition.direction
            if not partition.isscalar:
                partition.direction = dict(
                    [(dim_name_map[dim], value)
                     for dim, value in direction.iteritems()]
                    )
        #--- End: for
    #--- End: def

    def copy(self, _copy_attr=True):
        '''

Return a deep copy.

Do not set the `_copy_attr` parameter. It is for internal use only.

Equivalent to ``copy.deepcopy(pa)``.

:Returns:

    out :
        The deep copy.

**Examples**

>>> pa.copy()

''' 
        new = type(self)([])
        
        if _copy_attr:
            # Only copy the attributes once.
            new.dimensions = self.dimensions[:]
#            new.dimensions = self.dimensions[:] 
            new.direction  = copy(self.direction)
        #--- End: if

        if not self:
            return new

        if self._is_1d: #isinstance(self[0], Partition):
            for partition in self:
                new.append(partition.copy())
        else:
            for nested_array in self:
                # Recursive call
                pa = nested_array.copy(_copy_attr=False)
                new.append(pa)
        #--- End: if

        return new
    #--- End: def

    def expand_dims(self, pdim):
        '''

Insert a new size 1 partition dimension in place.

The new parition dimension is inserted at position 0.

:Parameters:

    pdim : str
        The name of the new partition dimension.

:Returns:

    None

**Examples**

>>> pa.dimensions
['dim0', 'dim1']
>>> pa.expand_dims('dim2')
>>> pa.dimensions
['dim2', 'dim0', 'dim1']

'''
        if not self.dimensions:
            self.dimensions = [pdim]
        else:
            self.dimensions.insert(0, pdim)        
            self._list = [self._list]
    #--- End: def
      
    def flat(self):
        '''

Return a flat iterator over the Partition objects in the partition
array.

:Returns:

    out : generator
        An iterator over the elements of the partition array.

**Examples**

>>> type(pa.flat())
<generator object flat at 0x145a0f0>
>>> for partition in pa.flat():
...     print partition.Units

'''
        if not self:
            return

        if self._is_1d: # isinstance(self[0], Partition):
            for partition in self:
                yield partition
        else:        
            for nested_partitions in self:
                for x in nested_partitions.flat():
                    yield x
    #--- End: def

    def info(self, attr):
        '''

'''
        if not self:
            return []

        out = []
        if self._is_1d: #isinstance(self[0], Partition):
            for partition in self:
                out.append(getattr(partition, attr, None))
        else:
            for nested_partition in self:
                # Recursive call
                out.append(nested_partition.info(attr))
        #--- End: if

        return out
    #--- End: def

#    def map(self, func, *args, **kwargs):
#        '''
#
#Return a list of the results of applying a function to each partition.
#
#:Parameters:
#
#    func : function
#
#    args, kwargs :
#
#:Returns:
#
#    out : list
#
#
#**Examples**
#
#>>> pa.map(min)
#[1, 9, -6]
#
#'''
#        out = []
#
#        conform_args = self.conform_args(revert_to_file=True)
#
#        for partition in self.flat():
#            array = partition.conform(**conform_args)
#            out.append(func(array, *args, **kwargs))
#            partition.close()
#        #--- End: for
#
#        return out
#    #--- End: def

    def partition_boundaries(self, adimensions):
        '''

'''            
        boundaries = {}

        zeros = [0] * len(self.dimensions)

        for j, pdim in enumerate(self.dimensions):

            first = self

            indices    = zeros[:]
            indices[j] = True

            while indices:
                index = indices.pop(0)
                if not index:
                    first = first[index]      # index is 0
                else:
                    break
            #--- End: while

            temp = []
            for element in first:
                for index in indices:
                    element = element[index]  # index is 0
                temp.append(element)
            #--- End: for
            first = temp
            # 'first' should now be list of Partition objects

#            i = self.dimensions.index(pdim)
            i = adimensions.index(pdim)

            b = [partition.location[i][1] for partition in first]
            b.insert(0, 0)

            boundaries[pdim] = b
        #--- End: for

        return boundaries
    #--- End: def

    def ravel(self):
        '''

Return a flattened partition array as a built-in list.

:Returns:

    out : list of Partitions
        A list containing the flattened partition array.

**Examples**

>>> x = partitions.ravel()
>>> type(x)
list

'''
        if not self:
            return []

        out = []
        if self._is_1d: # isinstance(self[0], Partition):
            out.extend(self)
        else:        
            for nested_partitions in self:
                out.append(nested_partitions.ravel())
        #--- End: if

        return out
    #--- End: def

    def rollaxis(self, axis, start=0):
        '''

Roll the specified partition dimension backwards,in place until it
lies in a given position.

This does not change the master array.

:Parameters:

    axis : int 
        The axis to roll backwards. The positions of the other axes do
        not change relative to one another.

    start : int, optional 
        The axis is rolled until it lies before this position.  The
        default, 0, results in a "complete" roll.

:Returns:

    None

**Examples**

>>> pa.rollaxis(2)
>>> pa.rollaxis(2, start=1)

'''
        if axis != start:
            self._list = numpy.rollaxis(numpy.array(self._list), 
                                        axis, start=start
                                        ).tolist()
            self.dimensions.insert(start, self.dimensions.pop(axis))
    #--- End: def
             
    def set_location_map(self, adimensions):
        '''

Recalculate the `location` attribute of each Partition object in the
partition array in place.

The locations are built up from the `shape` attribute of the
partitions, starting from 0 along every dimension.

:Parameters:

    adimensions : list
        Dimension identities. The position of a dimension in this
        list is the index used into each partition's `location` and
        `shape` lists.

:Returns:

    None

**Examples**

>>> pa.set_location_map(adimensions)

'''
        def _set_location_map(partitions, starts, stops, dimensions, 
                              index, level):
            '''

Set the locations of the partitions in `partitions`, recursing
through any nested partition arrays.

`starts` and `stops` are dictionaries, keyed by dimension, which are
shared between all of the recursive calls and updated in
place. `dimensions` is the list of partition dimensions, `index` maps
each dimension to its position in a partition's `location` and
`shape` lists and `level` is the position in `dimensions` of the
partition dimension of the enclosing level of nesting (-1 at the
top).

'''
            if partitions._is_1d: # isinstance(partitions[0], Partition):
                # The partition dimension along which these
                # partitions are laid out, if there is one
                if not dimensions:
                    inner_pdim = None
                else:
                    inner_pdim = dimensions[-1]

                for dim, start in starts.iteritems():
                    i = index[dim]
                    if dim == inner_pdim:
                        # The partitions follow on from each other
                        # along this dimension, each one starting
                        # where the previous one stopped
                        for partition in partitions:
                            stop                  = start + partition.shape[i]
                            partition.location[i] = (start, stop)
                            start                 = stop
                        #--- End: for
                        stops[dim] = 0
                    else:
                        # All of the partitions get the same location
                        # along this dimension, with its extent taken
                        # from the first partition
                        stop     = start + partitions[0].shape[i]
                        location = (start, stop)
                        for partition in partitions:
                            partition.location[i] = location

                        # Remember where this group stopped, but only
                        # for partition dimensions
                        if dim in dimensions:
                            stops[dim] = stop
                #--- End: for
            else:
                level += 1
                for nested_partitions in partitions:
                    # Along this level's partition dimension, the
                    # next nested array starts where the previous one
                    # stopped
                    starts[dimensions[level]] = stops[dimensions[level]]

                    # Every deeper partition dimension starts again
                    # from zero
                    for dim in dimensions[level+1:]:
                        starts[dim] = 0
                        stops[dim]  = 0
                    #--- End: for

                    # Recursive call
                    _set_location_map(nested_partitions,
                                      starts,
                                      stops,
                                      dimensions,
                                      index,
                                      level)
                #--- End: for
            #--- End: if
        #--- End: def

        # Every dimension starts and stops at zero
        starts = dict([(dim, 0) for dim in adimensions])
        stops  = starts.copy()

        # The position of each dimension in a partition's location
        # and shape lists
        index  = dict([(dim, i) for i, dim in enumerate(adimensions)])

        # The top of the partition array is one level above the
        # outermost partition dimension
        level  = -1

        _set_location_map(self, starts, stops, self.dimensions, index, level)
    #--- End: def

    def squeeze(self):
        '''

Remove all size 1 partition dimensions in place.

:Returns:

    None

**Examples**

>>> pa.dimensions
['dim0', 'dim1', 'dim2']
>>> pa._list
[[[<cf.partition.Partition object at 0x145daa0>,
   <cf.partition.Partition object at 0x145db90>]],
 [[<cf.partition.Partition object at 0x145dc08>,
   <cf.partition.Partition object at 0x145dd70>]]]
>>> pa.squeeze()
>>> pa.dimensions
['dim0', 'dim2']
>>> pa._list
[[<cf.partition.Partition object at 0x145daa0>,
  <cf.partition.Partition object at 0x145db90>],
 [<cf.partition.Partition object at 0x145dc08>,
  <cf.partition.Partition object at 0x145dd70>]]

'''        
        if self._is_1d:
            return

        axes1 = []
        axes2 = []

        pa = self
        for i, pdim in enumerate(self.dimensions):
            if len(pa) == 1:
                axes1.append(i)
            else:
                axes2.append(i)

            pa = pa[0]
        #--- End: for

        # Move the partition dimensions to be squeezed to the slowest
        # varying positions
        self.transpose(axes1+axes2)
  
        for pdim in axes1:
            self.dimensions.pop(0)
            if not isinstance(self._list[0], Partition):
                self._list = self._list[0]
        #--- End: for       
    #--- End: def

#    def to_disk(self):
#        '''
#
#Store each partition's data on disk in place.
#
#There is no change to partitions with data that are already on disk.
#
#
#:Returns:
#
#    None
#
#**Examples**
#
#>>> pa.to_disk()
#
#'''
#        for partition in self.flat():
#            if partition.in_memory:
#                partition.to_disk()
#    #--- End: def

    def transpose(self, axes):
        '''

Permute the partition dimensions of the partition array in place.

This does not change the master array.

:Parameters:

    axes : sequence of ints 
        Permute the axes according to the values given.

:Returns:

    None

**Examples**

>>> pa.transpose((2,0,1))

'''
        dimensions = self.dimensions
        if axes != range(len(dimensions)):
            self._list = numpy.transpose(numpy.array(self._list), 
                                         axes=axes
                                         ).tolist()
            self.dimensions = [dimensions[i] for i in axes]
    #--- End: def
             
#--- End: class


def empty(shape, **kwargs):
    '''

Return a new partition array of the given shape in which every
element is a new Partition object created with no arguments.

:Parameters:

    shape : sequence of ints
        The size of each partition dimension.

    kwargs :
        Passed on to the PartitionArray initialization.

:Returns:

    out : PartitionArray

**Examples**

>>> pa = empty((12, 73))

'''
    def _nested(sizes):
        '''

Return a nested list with the given sizes, or a single new Partition
object if there are no sizes left.

'''
        if sizes:
            inner_sizes = sizes[1:]
            # Recursive calls, one for each element of this level
            return [_nested(inner_sizes) for _ in xrange(sizes[0])]

        return Partition()
    #--- End: def

    nested_list = _nested(shape)

    return PartitionArray(nested_list, **kwargs)
#--- End: def
